{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 稠密连接网络（DenseNet）\n",
    "\n",
    "ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个：稠密连接网络（DenseNet） [1]。 它与ResNet的主要区别如图5.10所示。\n",
    "\n",
    "![ResNet（左）与DenseNet（右）在跨层连接上的主要区别：使用相加和使用连结](../img/densenet.svg)\n",
    "\n",
    "图5.10中将部分前后相邻的运算抽象为模块$A$和模块$B$。与ResNet的主要区别在于，DenseNet里模块$B$的输出不是像ResNet那样和模块$A$的输出相加，而是在通道维上连结。这样模块$A$的输出可以直接传入模块$B$后面的层。在这个设计里，模块$A$直接跟模块$B$后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。\n",
    "\n",
    "DenseNet的主要构建模块是稠密块（dense block）和过渡层（transition layer）。前者定义了输入和输出是如何连结的，后者则用来控制通道数，使之不过大。\n",
    "\n",
    "\n",
    "## 稠密块\n",
    "\n",
    "DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构（参见上一节的练习），我们首先在`conv_block`函数里实现这个结构。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "1"
    }
   },
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "from mxnet import gluon, init, nd\n",
    "from mxnet.gluon import nn\n",
    "\n",
    "def conv_block(num_channels):\n",
    "    blk = nn.Sequential()\n",
    "    blk.add(nn.BatchNorm(), nn.Activation('relu'),\n",
    "            nn.Conv2D(num_channels, kernel_size=3, padding=1))\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "稠密块由多个`conv_block`组成，每块使用相同的输出通道数。但在前向计算时，我们将每块的输入和输出在通道维上连结。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "class DenseBlock(nn.Block):\n",
    "    def __init__(self, num_convs, num_channels, **kwargs):\n",
    "        super(DenseBlock, self).__init__(**kwargs)\n",
    "        self.net = nn.Sequential()\n",
    "        for _ in range(num_convs):\n",
    "            self.net.add(conv_block(num_channels))\n",
    "\n",
    "    def forward(self, X):\n",
    "        for blk in self.net:\n",
    "            Y = blk(X)\n",
    "            X = nd.concat(X, Y, dim=1)  # 在通道维上将输入和输出连结\n",
    "        return X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在下面的例子中，我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时，我们会得到通道数为$3+2\\times 10=23$的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长，因此也被称为增长率（growth rate）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "8"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 23, 8, 8)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = DenseBlock(2, 10)\n",
    "blk.initialize()\n",
    "X = nd.random.uniform(shape=(4, 3, 8, 8))\n",
    "Y = blk(X)\n",
    "Y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 过渡层\n",
    "\n",
    "由于每个稠密块都会带来通道数的增加，使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过$1\\times1$卷积层来减小通道数，并使用步幅为2的平均池化层减半高和宽，从而进一步降低模型复杂度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [],
   "source": [
    "def transition_block(num_channels):\n",
    "    blk = nn.Sequential()\n",
    "    blk.add(nn.BatchNorm(), nn.Activation('relu'),\n",
    "            nn.Conv2D(num_channels, kernel_size=1),\n",
    "            nn.AvgPool2D(pool_size=2, strides=2))\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10，高和宽均减半。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 10, 4, 4)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = transition_block(10)\n",
    "blk.initialize()\n",
    "blk(Y).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DenseNet模型\n",
    "\n",
    "我们来构造DenseNet模型。DenseNet首先使用同ResNet一样的单卷积层和最大池化层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential()\n",
    "net.add(nn.Conv2D(64, kernel_size=7, strides=2, padding=3),\n",
    "        nn.BatchNorm(), nn.Activation('relu'),\n",
    "        nn.MaxPool2D(pool_size=3, strides=2, padding=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "类似于ResNet接下来使用的4个残差块，DenseNet使用的是4个稠密块。同ResNet一样，我们可以设置每个稠密块使用多少个卷积层。这里我们设成4，从而与上一节的ResNet-18保持一致。稠密块里的卷积层通道数（即增长率）设为32，所以每个稠密块将增加128个通道。\n",
    "\n",
    "ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。这里我们则使用过渡层来减半高和宽，并减半通道数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [],
   "source": [
    "num_channels, growth_rate = 64, 32  # num_channels为当前的通道数\n",
    "num_convs_in_dense_blocks = [4, 4, 4, 4]\n",
    "\n",
    "for i, num_convs in enumerate(num_convs_in_dense_blocks):\n",
    "    net.add(DenseBlock(num_convs, growth_rate))\n",
    "    # 上一个稠密块的输出通道数\n",
    "    num_channels += num_convs * growth_rate\n",
    "    # 在稠密块之间加入通道数减半的过渡层\n",
    "    if i != len(num_convs_in_dense_blocks) - 1:\n",
    "        num_channels //= 2\n",
    "        net.add(transition_block(num_channels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "同ResNet一样，最后接上全局池化层和全连接层来输出。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.add(nn.BatchNorm(), nn.Activation('relu'), nn.GlobalAvgPool2D(),\n",
    "        nn.Dense(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 获取数据并训练模型\n",
    "\n",
    "由于这里使用了比较深的网络，本节里我们将输入高和宽从224降到96来简化计算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on gpu(0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.5374, train acc 0.810, test acc 0.878, time 13.5 sec\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, loss 0.3189, train acc 0.882, test acc 0.890, time 12.2 sec\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, loss 0.2677, train acc 0.903, test acc 0.885, time 12.2 sec\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, loss 0.2358, train acc 0.913, test acc 0.881, time 12.2 sec\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, loss 0.2136, train acc 0.922, test acc 0.857, time 12.3 sec\n"
     ]
    }
   ],
   "source": [
    "lr, num_epochs, batch_size, ctx = 0.1, 5, 256, d2l.try_gpu()\n",
    "net.initialize(ctx=ctx, init=init.Xavier())\n",
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx,\n",
    "              num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 在跨层连接上，不同于ResNet中将输入与输出相加，DenseNet在通道维上连结输入与输出。\n",
    "* DenseNet的主要构建模块是稠密块和过渡层。\n",
    "\n",
    "## 练习\n",
    "\n",
    "* DenseNet论文中提到的一个优点是模型参数比ResNet的更小，这是为什么？\n",
    "* DenseNet被人诟病的一个问题是内存或显存消耗过多。真的会这样吗？可以把输入形状换成$224\\times 224$，来看看实际的消耗。\n",
    "* 实现DenseNet论文中的表1提出的不同版本的DenseNet [1]。\n",
    "\n",
    "\n",
    "\n",
    "## 参考文献\n",
    "\n",
    "[1] Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (Vol. 1, No. 2).\n",
    "\n",
    "## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/1664)\n",
    "\n",
    "![](../img/qr_densenet.svg)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}